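Loading the Stable Diffusion model
Text-to-image generation
Speeding up generation with a different scheduler
Batch generation
Reducing memory usage during generation
Generating with fine-tuned models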
We use the diffusers package from Hugging Face to generate images.
!pip install diffusers
!pip install transformers
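Load the v1-5 model with StableDiffusionPipeline.
The model weights are stored as float32 by default. To reduce the memory they occupy, we can lower their precision, for example by loading them in half precision with torch_dtype=torch.float16. This halves the size of the weights, and the effect on image quality is very small.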
from diffusers import StableDiffusionPipeline
import torch
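# Note: runwayml/stable-diffusion-v1-5 was later removed from the Hugging Face Hub;
# if loading fails, try the mirror "stable-diffusion-v1-5/stable-diffusion-v1-5"
model_id = "runwayml/stable-diffusion-v1-5"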
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
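To put the float16 saving into numbers, here is a small stdlib-only sketch. It computes the raw weight size from a parameter count and rounds a value through IEEE half precision using Python's struct format "e". The 860M parameter count is an assumption (the figure commonly quoted for the SD v1 UNet), and the estimate ignores activations and other runtime overhead.

```python
import struct

# Bytes per parameter: struct "f" is a 32-bit float, "e" is a 16-bit (half) float
BYTES_PER_PARAM = {"float32": struct.calcsize("f"), "float16": struct.calcsize("e")}


def weight_size_gib(num_params: int, dtype: str) -> float:
    """Raw weight size in GiB; ignores activations and framework overhead."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


def round_trip_fp16(x: float) -> float:
    """Round a Python float to half precision and convert it back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


n = 860_000_000  # assumed parameter count, roughly that of the SD v1 UNet
for dtype in ("float32", "float16"):
    print(f"{dtype}: {weight_size_gib(n, dtype):.2f} GiB")

w = 0.0123456789  # an arbitrary example weight
print(f"{w} -> {round_trip_fp16(w)}")
```

Half precision keeps roughly 3 significant decimal digits (an 11-bit significand), so each individual weight moves only slightly, which is why the drop in output quality is usually hard to notice.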
Calling to("cuda") moves the pipeline onto the GPU, which speeds up generation considerably.
pipe = pipe.to("cuda")
prompt = "portrait photo of a old warrior chief"
image = pipe(prompt).images[0]
image

prompt = "Beach, sailboat, cruise ship"
image = pipe(prompt).images[0]
image

prompt = "Underwater City"
image = pipe(prompt).images[0]
image

List all schedulers compatible with the current pipeline:
pipe.scheduler.compatibles
[diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler]
Show the scheduler currently in use (PNDMScheduler by default for this model):
pipe.scheduler
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.19.3",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"timestep_spacing": "leading",
"trained_betas": null
}
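The config above shows num_train_timesteps = 1000 and timestep_spacing = "leading": at inference time the scheduler visits only a subset of the 1000 training timesteps, so fewer num_inference_steps means fewer UNet evaluations. The pure-Python sketch below reproduces "leading" spacing as DDIMScheduler computes it (our reading of that scheduler; other schedulers and diffusers versions compute their timesteps somewhat differently), so treat it as an illustration rather than the exact schedule PNDM or DPM-Solver uses.

```python
def leading_timesteps(num_inference_steps: int,
                      num_train_timesteps: int = 1000,
                      steps_offset: int = 1) -> list[int]:
    """Evenly spaced subset of training timesteps, largest (noisiest) first."""
    step_ratio = num_train_timesteps // num_inference_steps
    steps = [i * step_ratio + steps_offset for i in range(num_inference_steps)]
    # Denoising starts from the noisiest timestep and walks down to the cleanest
    return steps[::-1]


for n in (50, 20):
    ts = leading_timesteps(n)
    print(f"{n} steps: {ts[:3]} ... {ts[-3:]}")
```

With 20 steps the sampler jumps 50 timesteps at a time instead of 20, so each image needs 2.5x fewer UNet passes. A higher-order solver such as DPMSolverMultistepScheduler is designed to keep quality acceptable at such low step counts.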
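Switch to DPMSolverMultistepScheduler and lower num_inference_steps to 20 (the pipeline default is 50) to speed up generation. from_config builds the new scheduler from the current scheduler's configuration, so settings such as the beta schedule carry over.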
from diffusers import DPMSolverMultistepScheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
prompt = "Underwater City"
pipe = pipe.to("cuda")
image = pipe(prompt, num_inference_steps=20).images[0]
image

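Build the pipeline inputs: one copy of the prompt and one seeded generator per image.
def get_inputs(batch_size=1):
    # One generator per image, seeded 0, 1, 2, ..., so every image is reproducible
    generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
    # Reads the global variable `prompt` defined earlier
    prompts = batch_size * [prompt]
    # Back to 50 denoising steps per image
    num_inference_steps = 50
    return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}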
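Giving each image its own generator seeded with its index means image i always starts from the same noise regardless of batch size: the first 8 images of a 16-image batch start from the same latents as an 8-image batch (the final pixels may still differ very slightly, since batched GPU math is not guaranteed to be bit-exact). The sketch below illustrates the idea with Python's random.Random standing in for torch.Generator; the function name initial_noise and the tiny latent size are hypothetical placeholders.

```python
import random


def initial_noise(batch_size: int, latent_size: int = 4) -> list[list[float]]:
    """One independently seeded generator per image, mirroring get_inputs()."""
    generators = [random.Random(i) for i in range(batch_size)]
    # Each image's noise is drawn only from its own generator
    return [[g.gauss(0.0, 1.0) for _ in range(latent_size)] for g in generators]


batch_8 = initial_noise(8)
batch_16 = initial_noise(16)
print("first 8 match:", batch_16[:8] == batch_8)
```

With a single shared generator, by contrast, the noise for image i would depend on how many values had been drawn before it, so changing the batch size would change every image after the first.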
Display the images in a grid:
from PIL import Image
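def image_grid(imgs, rows=2, cols=2):
    # All images are assumed to have the same size as the first one
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Image i goes to column i % cols, row i // cols
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid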
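image_grid places image i at column i % cols and row i // cols on a canvas of cols*w by rows*h pixels. The sketch below computes just those paste positions without PIL, assuming 512x512 images (the v1-5 default output size). It also adds a check the original helper lacks: with more images than rows * cols, the extra images would land outside the canvas. The name grid_boxes is hypothetical.

```python
def grid_boxes(n: int, w: int, h: int, rows: int, cols: int) -> list[tuple[int, int]]:
    """Top-left paste position of each of n images of size w x h in a rows x cols grid."""
    if n > rows * cols:
        raise ValueError(f"{n} images do not fit in a {rows}x{cols} grid")
    # Fill left to right, then top to bottom, exactly like image_grid()
    return [((i % cols) * w, (i // cols) * h) for i in range(n)]


# 8 images of 512 x 512 in 2 rows x 4 columns -> canvas of 2048 x 1024
for i, box in enumerate(grid_boxes(8, 512, 512, rows=2, cols=4)):
    print(i, box)
```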
Generate 8 images in one batch:
images = pipe(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4)

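Generate 16 images in one batch.
Producing this many images at once can exhaust GPU memory (OOM), so we call enable_attention_slicing(), which computes attention in several smaller slices one after another instead of all at once. This lowers peak memory at a small cost in speed.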
pipe.enable_attention_slicing()
images = pipe(**get_inputs(batch_size=16)).images
image_grid(images, rows=4, cols=4)
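Why slicing helps: the largest tensor inside self-attention is the score matrix, whose size grows with the square of the number of tokens. The back-of-the-envelope sketch below estimates it under stated assumptions: 512x512 output (a 64x64 latent, i.e. 4096 tokens in the highest-resolution attention layers), 8 attention heads, float16 scores, and classifier-free guidance doubling the batch. It also assumes slicing splits the (batch x heads) dimension into chunks computed one after another, which is our understanding of how diffusers' sliced attention works. The figures are rough estimates, not measurements, and ignore all other activations.

```python
import math


def attn_scores_gib(batch_size: int, slices: int = 1, tokens: int = 4096,
                    heads: int = 8, bytes_per_elem: int = 2, cfg: bool = True) -> float:
    """Size in GiB of the attention scores held in memory at once (tokens x tokens per head)."""
    # Classifier-free guidance runs conditional and unconditional passes together
    effective_batch = batch_size * (2 if cfg else 1)
    rows = effective_batch * heads              # the (batch x heads) dimension
    rows_per_slice = math.ceil(rows / slices)   # only one slice is computed at a time
    return rows_per_slice * tokens * tokens * bytes_per_elem / 2**30


for slices in (1, 2, 16):
    print(f"batch 16, {slices:>2} slice(s): {attn_scores_gib(16, slices):.2f} GiB")
```

The total amount of work is unchanged; it is just split into more, smaller passes, which is where the modest slowdown comes from.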

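openjourney is a model fine-tuned on a dataset of images generated by Midjourney.
There are many similar fine-tuned models; each is fine-tuned on a different dataset and can generate a different kind of image.
A selection of them is listed at https://huggingface.co/spaces/huggingface-projects/diffusers-gallery.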
from diffusers import StableDiffusionPipeline
import torch
model_id = "prompthero/openjourney"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "retro serie of different cars with different colors and shapes, mdjrny-v4 style"
image = pipe(prompt).images[0]
image

prompt = "Underwater City, mdjrny-v4 style"
image = pipe(prompt).images[0]
image

images = pipe(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4)
The complete runnable code for everything above is available as a Colab notebook:
https://github.com/erberry/ThinkML/blob/main/stable_diffusion_v1_5.ipynb